Infant Mortality in Asian Countries: A Data-Driven Exploration¶

Introduction¶

Infant mortality, the death of infants within the first year of life, is a critical indicator of a population's health and well-being. Understanding the factors that contribute to infant mortality rates is crucial for policymakers, healthcare professionals, and researchers seeking to improve the health outcomes of vulnerable populations.

This project focuses on examining infant mortality in Asian countries, a region known for its diverse population, cultural richness, and varying levels of socioeconomic development. By investigating the factors associated with infant mortality in Asia, we aim to gain insights into the unique challenges and opportunities for improving infant health in this region.

The study utilizes a combination of cross-sectional and time series data to capture both the static and dynamic aspects of infant mortality. The dataset includes information on various factors such as adolescent birth rates, access to healthcare services, sanitation facilities, and socioeconomic indicators. By analyzing these variables, we can identify key drivers and patterns that contribute to variations in infant mortality rates across different countries in Asia.

In this project, we employ advanced data analysis techniques, including exploratory data analysis, regression modeling, time series analysis, and machine learning, to uncover important insights and develop predictive models. The ultimate goal is to provide evidence-based recommendations and interventions that can effectively reduce infant mortality rates and improve the health outcomes of infants and their families in Asian countries.

Through this project, we aim to contribute to the broader field of public health and inform policymakers, healthcare practitioners, and other stakeholders about effective strategies to address the complex issue of infant mortality in Asia. By understanding the underlying factors and trends, we can work towards creating targeted interventions and policies that have the potential to save countless infant lives and ensure a healthier future for the region.

Data¶

  • Mortality rate, infant (per 1,000 live births)

  • Physicians (per 1,000 people)

  • Nurses and midwives (per 1,000 people)

  • Births attended by skilled health staff (% of total)

  • Total alcohol consumption per capita (liters of pure alcohol, projected estimates, 15+ years of age)

  • Adolescent fertility rate (births per 1,000 women ages 15-19)

  • Current health expenditure (% of GDP)

  • People using at least basic sanitation services (% of population): Basic sanitation is defined as having access to facilities for the safe disposal of human waste (feces and urine), as well as having the ability to maintain hygienic conditions, through services such as garbage collection, industrial/hazardous waste management, and wastewater treatment and disposal.

  • People using at least basic drinking water services (% of population): Basic drinking water service is defined as drinking water from an improved source, provided collection time is not more than 30 minutes for a round trip.

  • People with basic handwashing facilities including soap and water (% of population): The percentage of population living in households that have a handwashing facility with soap and water at home. Handwashing facilities can consist of a sink with tap water, but can also include other devices that contain, transport or regulate the flow of water. Buckets with taps, tippy-taps and portable basins are all examples of handwashing facilities. Bar soap, liquid soap, powder detergent and soapy water all count as soap for monitoring purposes.

  • People using safely managed sanitation services (% of population): Population using an improved sanitation facility that is not shared with other households and where excreta are safely disposed of in situ or treated off site. Improved sanitation facilities include flush/pour flush to piped sewer systems, septic tanks or pit latrines; pit latrines with slabs (including ventilated pit latrines), and composting toilets.

In [52]:
import os
import pandas as pd

directory = "G:/Documents/Projects/InfantMortality/data"

csv_files = os.listdir(directory)

# Read each CSV file and assign it to a variable named after the file
for file in csv_files:
    file_path = os.path.join(directory, file)
    df_name = os.path.splitext(file)[0]  # File name without extension
    globals()[df_name] = pd.read_csv(file_path)
In [53]:
# Getting the names of all dataframes
df_names = [os.path.splitext(f)[0] for f in csv_files]
In [54]:
# A function to unpivot a dataframe
def unpivot(dfname):
    df = globals()[dfname]
    
    df.drop(['Indicator Code'], axis = 1, inplace=True)
    
    df_reshaped = pd.melt(
        df,
        id_vars=['Country Name', 'Indicator Name', 'Country Code'],
        value_vars=[c for c in df.columns if c not in ['Country Name', 'Indicator Name', 'Country Code']],
        var_name='Year',
        value_name=dfname
    )
    
    # Removing the constant Indicator column
    df_reshaped.drop('Indicator Name', axis = 1, inplace = True)
    
    #Rename the columns
    df_reshaped.rename(columns={'Country Name': 'Country'}, inplace=True)

    # Sort the DataFrame by Country and Year
    df_reshaped.sort_values(['Country', 'Year'], inplace=True)

    # Reset the index
    df_reshaped.reset_index(drop=True, inplace=True)
    
    return df_reshaped
In [55]:
for dfname in df_names:
    globals()[dfname] = unpivot(dfname)
In [56]:
# Ensure Year is an int64 in all dataframes; if not, change it
for dfname in df_names:
    df = globals()[dfname]
    if df['Year'].dtype != 'int64':
        df['Year'] = df['Year'].astype('int64')
In [57]:
# Calculating the global average infant mortality rate (all countries, before filtering to Asia)
global_avg_mortality = infantMortalityRate.infantMortalityRate.mean()
In [58]:
df_names2 = df_names.copy()
df_names2.remove('infantMortalityRate')
merged_df = infantMortalityRate.copy()

# Iterate over the remaining DataFrames and merge them with the previously merged DataFrame
for dfname in df_names2:
    merged_df = pd.merge(merged_df, globals()[dfname], on=['Country','Year','Country Code'], how = 'outer')
In [59]:
country_metadata = pd.read_csv("G:/Documents/Projects/InfantMortality/metadata/country_metadata.csv", delimiter = ';')
In [60]:
economy_class = pd.read_excel("G:/Documents/Projects/InfantMortality/metadata/economy_class.xlsx")
economy_class.rename(columns = {'Code':'ISO-alpha3 Code'}, inplace = True)
economy_class.drop(['Lending category','Economy','Region'], axis = 1, inplace = True)
country_metadata = country_metadata.merge(economy_class, on = 'ISO-alpha3 Code', how = 'left')
In [61]:
asia_data = merged_df[merged_df['Country Code'].isin(country_metadata['ISO-alpha3 Code'])].copy()
# Adding income group and region to the data
country_metadata.rename(columns = {'ISO-alpha3 Code':'Country Code'}, inplace = True)
asia_data = pd.merge(asia_data, country_metadata[['Sub-region Name','Income group','Country Code']],
                     on = 'Country Code', how = 'left')
asia_data.rename(columns = {'Sub-region Name': 'Region'}, inplace = True)
In [62]:
# Confirming that we have all the Asian countries
len(asia_data.Country.unique()) == len(country_metadata['Country Code'])
Out[62]:
True
In [63]:
# removing rows where infant mortality is NA
asia_data.dropna(subset = ['infantMortalityRate'], inplace = True)
In [64]:
# Finding the years with the fewest NAs
nonNA_count = asia_data.groupby('Year').count().sum(axis=1)
In [65]:
print(nonNA_count[nonNA_count > 450])
# I choose years 2000-2020
Year
2000    650
2001    543
2002    544
2003    550
2004    576
2005    664
2006    581
2007    582
2008    593
2009    594
2010    693
2011    581
2012    592
2013    590
2014    612
2015    675
2016    582
2017    590
2018    676
2019    607
2020    574
dtype: int64
In [66]:
asia_data = asia_data[asia_data.Year.isin(range(2000,2021))].copy()

Handling missing values¶

I decided to use interpolation to fill in the missing values. To account for data patterns specific to each country, I group the data by country and interpolate separately within each country.

In [67]:
# Making a copy of asia_data in order not to mess it up when playing around with the code 
df = asia_data.copy()
In [68]:
# Checking the number of missing values for each column
na_count = df.isna().sum()
na_cols = na_count[na_count > 0].index.tolist()
In [69]:
na_count
Out[69]:
Country                         0
Country Code                    0
Year                            0
infantMortalityRate             0
adolescentBirthRate             0
alcohol                       791
atLeastBasicSanitation         21
basicDrinkingWaterServices      4
basicHandWashing              679
birthedByHealthstaff          454
gdpSpentHealthcare             65
nurseAndMidwife               376
physicians                    342
safelySanitation              362
tobaccoAge15                  721
Region                          0
Income group                    0
dtype: int64

Interpolation can only fill gaps in countries that have at least some non-NA values for a column; if a country's values are entirely missing, they remain NA. For those remaining NAs, we can use the median value of each year.

Setting limit_direction='both' ensures that the interpolation is performed in both forward and backward directions, allowing missing values at the beginning and end of the data to be filled.
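As a minimal sketch of this two-step fill (per-country interpolation, then a per-year median fallback) on toy data, with hypothetical `Country`/`Year`/`value` columns:

```python
import numpy as np
import pandas as pd

# Toy data: country A has gaps, country B is entirely NA for 'value'
toy = pd.DataFrame({
    'Country': ['A'] * 4 + ['B'] * 4,
    'Year': [2000, 2001, 2002, 2003] * 2,
    'value': [np.nan, 2.0, np.nan, 4.0] + [np.nan] * 4,
})

# Step 1: interpolate within each country, filling edge NaNs in both directions
toy['value'] = toy.groupby('Country')['value'].transform(
    lambda x: x.interpolate(method='linear', limit_direction='both'))

# Step 2: country B is still all-NA, so fall back to the per-year median
if toy['value'].isna().sum() != 0:
    toy['value'] = toy.groupby('Year')['value'].transform(lambda x: x.fillna(x.median()))

print(toy['value'].tolist())  # → [2.0, 2.0, 3.0, 4.0, 2.0, 2.0, 3.0, 4.0]
```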

In [70]:
for col in na_cols:
    # Interpolate within each country
    df[col] = df.groupby('Country')[col].transform(lambda x: x.interpolate(method='linear', limit_direction='both'))

    # Fall back to the per-year median for countries whose values were entirely NA
    if df[col].isna().sum() != 0:
        df[col] = df.groupby('Year')[col].transform(lambda x: x.fillna(x.median()))
In [71]:
df.isna().sum()
Out[71]:
Country                       0
Country Code                  0
Year                          0
infantMortalityRate           0
adolescentBirthRate           0
alcohol                       0
atLeastBasicSanitation        0
basicDrinkingWaterServices    0
basicHandWashing              0
birthedByHealthstaff          0
gdpSpentHealthcare            0
nurseAndMidwife               0
physicians                    0
safelySanitation              0
tobaccoAge15                  0
Region                        0
Income group                  0
dtype: int64
In [72]:
asia_data = df.copy()
asia_data.reset_index(drop = True, inplace = True)

Visualization¶

Infant Mortality Map¶

In [73]:
import seaborn as sns
import matplotlib.pyplot as plt
In [74]:
import geopandas as gp
import folium

# Load the country boundaries dataset
world = gp.read_file('G:/Documents/Projects/InfantMortality/metadata/ne_50m_admin_0_countries.shx')

# Filter the dataset for Asian countries
asian_countries = world[world.ISO_A3.isin(asia_data['Country Code'])]

avg_mortality = asia_data[['Country','infantMortalityRate']].groupby('Country').mean().round(2)
avg_mortality = avg_mortality.merge(asia_data[['Country Code','Country']].drop_duplicates(subset='Country'), on = 'Country')

merged_data = asian_countries.merge(avg_mortality, left_on='ISO_A3', right_on='Country Code')
merged_data = merged_data[['Country','infantMortalityRate', 'geometry']]
In [75]:
# Create a folium map centered around Asia
m = folium.Map(location=[35, 100], zoom_start=3)

# Add choropleth layer for infant mortality rate
folium.Choropleth(
    geo_data=merged_data,
    name='Infant Mortality Rate',
    data=merged_data,
    columns=['Country', 'infantMortalityRate'],
    key_on='feature.properties.Country',
    fill_color='YlOrRd',
    fill_opacity=0.7,
    line_opacity=0.2,
    legend_name='Infant Mortality Rate',
).add_to(m)

# Add country name and infant mortality rate as tooltips
for _, row in merged_data.iterrows():
    country = row['Country']
    infant_mortality = row['infantMortalityRate']
    tooltip = f"{country}: {infant_mortality}"
    folium.Marker([row['geometry'].centroid.y, row['geometry'].centroid.x], tooltip=tooltip).add_to(m)

# Display the map
m
Out[75]:
[Interactive choropleth map of average infant mortality rates across Asian countries, with per-country tooltips]

Line plot of infant mortality through the years¶

In [76]:
import plotly.graph_objects as go

countries = asia_data['Country'].unique()

traces = []

# Iterate over each country and create a trace for the line plot
for country in countries:
    # Filter the data for the current country
    data = asia_data[asia_data['Country'] == country]
    
    # Create a trace for the current country
    trace = go.Scatter(
        x=data['Year'],
        y=data['infantMortalityRate'],
        name=country,
        line=dict(width=2)
    )
    
    # Add the trace to the list of traces
    traces.append(trace)

# Create the layout for the plot
layout = go.Layout(
    title='Infant Mortality Rate by Country from 2000 to 2020',
    xaxis=dict(title='Year'),
    yaxis=dict(title='Infant Mortality Rate'),
    showlegend=True,
    updatemenus=[
        dict(
            buttons=list([
                dict(
                    label=country,
                    method='update',
                    args=[{'visible': [True if trace.name == country else False for trace in traces]}],
                )
                for country in countries
            ]) +
            [
                dict(
                    label='Show All',
                    method='update',
                    args=[{'visible': [True] * len(traces)}],
                )
            ],
            direction='down',
            showactive=True,
            x=1,
            xanchor='right',
            y=1.2,
            yanchor='top',
        )
    ]
)

fig = go.Figure(data=traces, layout=layout)

fig.show()

We observe a declining trend in infant mortality rates across the 21 years of our dataset, punctuated by occasional surges that can be attributed to specific events.

For instance, there was a notable surge in infant mortality in Myanmar in 2008, coinciding with the devastating impact of Cyclone Nargis, which made landfall there on May 2, 2008. In the case of Syria, the catalyst was the outbreak of the civil war on March 15, 2011. Sri Lanka experienced two spikes in infant mortality, in 2004 and 2009: in 2004 the country was ravaged by the Indian Ocean tsunami, and in 2009 the escalation of the civil war, coupled with multiple bombings, contributed to the observed increase.

Countries that achieved the most significant reductions in infant mortality rates¶

In [77]:
asia_data_grouped = asia_data.groupby('Country')['infantMortalityRate']
# Calculate the overall change in infant mortality rate for each country
overall_mortality_change = asia_data_grouped.last() - asia_data_grouped.first()

# Sort countries based on overall change
countries_best_improvement = overall_mortality_change.sort_values(ascending=True)

top_10_countries = countries_best_improvement.head(10)

print(top_10_countries)
Country
Cambodia      -57.1
Afghanistan   -45.8
Azerbaijan    -43.7
Timor-Leste   -42.7
Lao PDR       -41.2
India         -39.9
Tajikistan    -39.1
Bangladesh    -39.0
Uzbekistan    -37.5
Mongolia      -35.6
Name: infantMortalityRate, dtype: float64
In [78]:
# Filter the data for the top 10 countries
top_10_data = asia_data[asia_data['Country'].isin(top_10_countries.index)]

# Pivot the data to have rates for 2000 and 2020
pivot_data = top_10_data.pivot(index='Country', columns='Year', values='infantMortalityRate')

# Create a bar plot for the years 2000 and 2020
pivot_data[[2000, 2020]].plot(kind='bar', figsize=(12, 6))
plt.xlabel('Country')
plt.ylabel('Infant Mortality Rate')
plt.title('Top 10 Countries with the Best Improvement in Infant Mortality Rate (2000-2020)')
plt.xticks(rotation=45)
plt.legend(title='Year')
plt.tight_layout()
plt.show()

GDP and Infant Mortality¶

In [79]:
# Group the data by 'Country' and calculate the mean of 'gdpSpentHealthcare' and 'infantMortalityRate'
df_grouped = asia_data.groupby('Country')[['gdpSpentHealthcare', 'infantMortalityRate']].mean().reset_index()

plt.figure(figsize=(12, 6))
ax = sns.barplot(data=df_grouped, x='Country', y='gdpSpentHealthcare', errorbar=None)
plt.xlabel('Country')
plt.ylabel('% GDP spent on healthcare')
plt.title('% of GDP spent on healthcare and Average Infant Mortality Rate for Each Country in Asia')
plt.xticks(rotation=90)

# Add average infant mortality rate labels above the bars
for index, row in df_grouped.iterrows():
    ax.text(index, row['gdpSpentHealthcare'], f"{row['infantMortalityRate']:.2f}", ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

Surprisingly, although Afghanistan spends the highest share of its GDP on healthcare among these countries, its infant mortality rate remains high.

Access to basic and safe sanitation facilities¶

In [80]:
basic_sanit = asia_data.groupby('Country')['atLeastBasicSanitation'].mean().reset_index()
safe_sanit = asia_data.groupby('Country')['safelySanitation'].mean().reset_index()

plt.figure(figsize=(12, 6))
sns.barplot(data=basic_sanit, x='Country', y='atLeastBasicSanitation', alpha=0.6, label = 'atLeastBasicSanitation')
sns.barplot(data=safe_sanit, x='Country', y='safelySanitation', alpha=0.6, saturation=1, label = 'safelySanitation')
plt.xlabel('Country')
plt.ylabel('Percentage')
plt.title('People using at least basic sanitation services (% of population), averaged over the years for each country')
plt.xticks(rotation=90)
plt.tight_layout()
plt.grid(axis='y')
plt.legend()
plt.show()

Bar plot of average infant mortality for each income group¶

In [81]:
# Group the data by 'Income group' and calculate the mean of 'infant mortality'
df_grouped = asia_data.groupby('Income group')['infantMortalityRate'].mean() \
            .reset_index().sort_values(by = 'infantMortalityRate', ascending = False)

# Define a colormap
cmap = sns.color_palette("YlOrRd", len(asia_data['Income group'].unique()))
sns.barplot(data=df_grouped, x='Income group', y='infantMortalityRate', errorbar = None, palette=cmap[::-1])
plt.ylabel('Infant mortality rate')
plt.title('Infant mortality rate for each income group of Asian countries')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

Bar plot of average infant mortality for each region¶

In [82]:
# Group the data by 'Region' and calculate the mean of 'infant mortality'
df_grouped = asia_data.groupby('Region')['infantMortalityRate'].mean() \
            .reset_index().sort_values(by = 'infantMortalityRate', ascending = False)

# Define a colormap
cmap = sns.color_palette("tab10", len(asia_data['Region'].unique()))
sns.barplot(data=df_grouped, x='Region', y='infantMortalityRate', errorbar = None, palette=cmap[::-1])
plt.ylabel('Infant mortality rate')
plt.title('Infant mortality rate for each region in Asia')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

Correlation matrix heatmap¶

In [83]:
# Select only numeric columns for correlation calculation
numeric_columns = asia_data.select_dtypes(include='number')

numeric_columns.drop('Year', axis = 1, inplace = True)

# Calculate the correlation matrix
correlation_matrix = numeric_columns.corr()

# Create a heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', linewidths=0.5)
plt.title('Correlation Heatmap')
plt.show()

Hypothesis Testing¶

One-Sample T-Test¶

$H_0$: The mean infant mortality rate in Asian countries is equal to the global average.

$H_1$: The mean infant mortality rate in Asian countries differs from the global average.
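For reference, the one-sample t-statistic compares the sample mean $\bar{x}$ with the hypothesized mean $\mu_0$ (here, the global average):

$$t = \frac{\bar{x} - \mu_0}{s / \sqrt{n}},$$

where $s$ is the sample standard deviation and $n$ the sample size; under $H_0$, $t$ follows a t-distribution with $n-1$ degrees of freedom.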

In [84]:
import scipy.stats as stats

# Defining the significance level
alpha = 0.05

t_statistic, p_value = stats.ttest_1samp(asia_data.infantMortalityRate, global_avg_mortality)

# Print the results
print("T-Statistic:", t_statistic)
print("P-Value:", p_value)

# Checking if the p-value is less than the significance level to reject or fail to reject the null hypothesis
if p_value < alpha:
    print("Reject the null hypothesis. Mean value of infant mortality rates for Asian countries is significantly different from the global average.")
else:
    print("Failed to reject the null hypothesis. There is no significant difference in infant mortality rates in Asia and other parts of the world.")
T-Statistic: -41.51462329342729
P-Value: 5.318897758944627e-222
Reject the null hypothesis. Mean value of infant mortality rates for Asian countries is significantly different from the global average.

Two-Sample T-Test¶

Comparing infant mortality rates between countries with high income and countries of other income categories.

$H_0$: The mean infant mortality rate is the same between countries with high income and countries with other income categories.

$H_1$: The mean infant mortality rate is significantly higher in countries with other income categories compared to countries with high income.
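One caveat: `scipy.stats.ttest_ind` is two-sided by default, while the $H_1$ above is directional; the direction can be read off the sign of the t-statistic, or requested explicitly via the `alternative` argument. A minimal sketch on synthetic samples (all numbers hypothetical):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)
group_low = rng.normal(4.0, 1.0, 200)    # hypothetical high-income-style sample
group_high = rng.normal(25.0, 8.0, 200)  # hypothetical other-income-style sample

# Default two-sided test: detects any difference in means
t_two, p_two = stats.ttest_ind(group_low, group_high)

# One-sided test matching a directional H1 (first sample's mean is lower)
t_one, p_one = stats.ttest_ind(group_low, group_high, alternative='less')

# Same t-statistic; when t falls on the hypothesized side,
# the one-sided p-value is half the two-sided one
assert t_one == t_two and np.isclose(p_one, p_two / 2)
```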

In [85]:
high_income = asia_data[asia_data['Income group']=='High income'].infantMortalityRate
other_income = asia_data[asia_data['Income group']!='High income'].infantMortalityRate

# Perform the two-sample t-test
t_statistic, p_value = stats.ttest_ind(high_income, other_income)

# Print the results
print("T-Statistic:", t_statistic)
print("P-Value:", p_value)

if p_value < alpha:
    print('Reject the null hypothesis. The mean infant mortality rate is significantly higher in countries with other income categories compared to countries with high income.')
else:
    print('Failed to reject the null hypothesis.  The mean infant mortality rate is the same between countries with high income and countries with other income categories.')
        
T-Statistic: -19.629287053647694
P-Value: 4.261024079814796e-73
Reject the null hypothesis. The mean infant mortality rate is significantly higher in countries with other income categories compared to countries with high income.

ANOVA (Analysis of Variance)¶

Testing whether there are significant differences in infant mortality rates among different regions of Asian countries.

$H_0$: There are no significant differences in the mean infant mortality rates among different regions of Asian countries.

$H_1$: There are significant differences in the mean infant mortality rates among different regions of Asian countries.
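A toy illustration of `stats.f_oneway` (synthetic groups, all values hypothetical): two groups with similar means and one clearly shifted group yield a large F-statistic and a small p-value:

```python
import numpy as np
from scipy import stats

# Three synthetic "regions": two with similar means, one clearly shifted
region_a = np.array([1.0, 2.0, 3.0, 2.0])
region_b = np.array([2.0, 3.0, 4.0, 3.0])
region_c = np.array([10.0, 11.0, 12.0, 11.0])

f_stat, p_val = stats.f_oneway(region_a, region_b, region_c)
print(p_val < 0.05)  # True: at least one group mean differs
```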

In [86]:
# Group the data by 'Region' and extract the 'infantMortalityRate' column
grouped_data = asia_data.groupby('Region')['infantMortalityRate']

alpha = 0.05

# Convert groups to NumPy arrays
group_arrays = [group.values for _, group in grouped_data]

# Perform the ANOVA test
f_statistic, p_value = stats.f_oneway(*group_arrays)

# Print the results
print("F-Statistic:", f_statistic)
print("P-Value:", p_value)

if p_value < alpha:
    print('Reject the null hypothesis.')
else:
    print('Failed to reject the null hypothesis.')
F-Statistic: 63.23458406012437
P-Value: 8.41301886779941e-48
Reject the null hypothesis.

Fitting a model¶

The Performance Metrics¶

  • R-squared (R²): This tells us how well the model explains the variance in the target variable. Higher values (closer to 1) indicate a better fit.

  • Mean Squared Error (MSE): This measures the average squared difference between predicted and actual values. Lower values are better.

  • Mean Absolute Error (MAE): This represents the average absolute difference between predicted and actual values. Lower MAE values indicate a better model.

  • Root Mean Squared Error (RMSE): The average prediction error expressed in the same units as the target variable; it quantifies how far, on average, the model's predictions deviate from the actual values. Lower RMSE values indicate a more accurate model.
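To make these metrics concrete, a small sketch on toy predictions (numbers hypothetical):

```python
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_pred = np.array([2.5, 5.0, 8.0, 9.5])

r2 = r2_score(y_true, y_pred)              # 1 - SS_res/SS_tot = 1 - 1.5/20 = 0.925
mse = mean_squared_error(y_true, y_pred)   # mean of squared errors = 0.375
mae = mean_absolute_error(y_true, y_pred)  # mean of absolute errors = 0.5
rmse = np.sqrt(mse)                        # error in the units of the target
```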

In [87]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
import numpy as np

Preprocessing¶

In [88]:
columns_to_encode = ["Country", "Region", "Year"]
one_hot_encoder = OneHotEncoder(sparse_output=False, drop='first', handle_unknown = 'ignore')
one_hot_encoded = one_hot_encoder.fit_transform(asia_data[columns_to_encode])

# Create a new dataframe with the one-hot encoded columns
one_hot_encoded_df = pd.DataFrame(
    one_hot_encoded,
    columns=one_hot_encoder.get_feature_names_out(columns_to_encode)
)

# Label encode the "Income group" column
label_encoder = LabelEncoder()
label_encoded = label_encoder.fit_transform(asia_data["Income group"])

# Add the one-hot encoded and label encoded columns to the original dataframe
encoded_columns = one_hot_encoded_df.columns.tolist() + ["Income group"]
encoded_data = pd.concat([one_hot_encoded_df, pd.Series(label_encoded, name="Income group")], axis=1)

# Update the "asia_data" dataframe with the encoded columns
asia_data_encoded = pd.concat([asia_data.drop(columns=["Country", "Region", "Income group", "Country Code", 'Year']),
                               encoded_data], axis=1)
In [89]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
In [90]:
# Split the data into features (X) and target variable (y)
X = asia_data_encoded.drop("infantMortalityRate", axis=1)
y = asia_data_encoded["infantMortalityRate"]

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

Linear Regression¶

In [91]:
# Create the linear regression model and fit it on the training data
model_linreg = LinearRegression()
model_linreg.fit(X_train, y_train)

# Make predictions on the testing set
y_pred = model_linreg.predict(X_test)

# Calculate R-squared score
r2_linreg = r2_score(y_test, y_pred)

# Calculate mean squared error
mse_linreg = mean_squared_error(y_test, y_pred)

# Calculate mean absolute error
mae_linreg = mean_absolute_error(y_test, y_pred)

# Print the evaluation metrics
print("R-squared:", r2_linreg)
print("Mean Squared Error:", mse_linreg)
print("Root Mean Squared Error:", np.sqrt(mse_linreg))
print("Mean Absolute Error:", mae_linreg)
R-squared: 0.9690690011305333
Mean Squared Error: 10.323145389268479
Root Mean Squared Error: 3.212965202000868
Mean Absolute Error: 2.370692908843419

XGBoost¶

In [92]:
import xgboost as xgb

model = xgb.XGBRegressor(random_state = 42)

# Train the model
model.fit(X_train, y_train)

y_pred_xgb = model.predict(X_test)

r2_xgb = r2_score(y_test, y_pred_xgb)
mse_xgb = mean_squared_error(y_test, y_pred_xgb)
mae_xgb = mean_absolute_error(y_test, y_pred_xgb)

print("R-squared:", r2_xgb)
print("Mean Squared Error:", mse_xgb)
print("Root Mean Squared Error:", np.sqrt(mse_xgb))
print("Mean Absolute Error:", mae_xgb)
R-squared: 0.9861114650523148
Mean Squared Error: 4.635264645475793
Root Mean Squared Error: 2.1529664756971467
Mean Absolute Error: 1.352110361744285

XGBoost with hyperparameter tuning¶

In [93]:
from sklearn.model_selection import GridSearchCV

# Define the parameter grid to search through
param_grid = {
    'learning_rate': [0.1, 0.01, 0.001],
    'max_depth': [3, 5, 7],
    'n_estimators': [100, 200, 300]
}

# Perform grid search with cross-validation
grid_search = GridSearchCV(model, param_grid, cv=5)
grid_search.fit(X_train, y_train)

# Get the best hyperparameters
best_params = grid_search.best_params_
In [94]:
best_params
Out[94]:
{'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 300}
In [95]:
# Train the model with the best hyperparameters
best_model = xgb.XGBRegressor(**best_params)
best_model.fit(X_train, y_train)

# Make predictions on the test set
y_pred = best_model.predict(X_test)

r2_xgb_tuned = r2_score(y_test, y_pred)
mse_xgb_tuned = mean_squared_error(y_test, y_pred)
mae_xgb_tuned = mean_absolute_error(y_test, y_pred)

print("R-squared:", r2_xgb_tuned)
print("Mean Squared Error:", mse_xgb_tuned)
print("Root Mean Squared Error:", np.sqrt(mse_xgb_tuned))
print("Mean Absolute Error:", mae_xgb_tuned)
R-squared: 0.9884233112087157
Mean Squared Error: 3.863691632561967
Root Mean Squared Error: 1.9656275416675375
Mean Absolute Error: 1.19585860562556
In [96]:
importances = best_model.feature_importances_

feature_importance_df = pd.DataFrame({'Feature': X_train.columns, 'Importance': importances})
feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)

feature_importance_df.head(12)
Out[96]:
Feature Importance
83 Income group 0.357954
5 birthedByHealthstaff 0.329962
4 basicHandWashing 0.082619
41 Country_Pakistan 0.041620
3 basicDrinkingWaterServices 0.034353
53 Country_Turkmenistan 0.022690
7 nurseAndMidwife 0.016145
10 tobaccoAge15 0.013912
55 Country_Uzbekistan 0.013654
39 Country_Nepal 0.011806
8 physicians 0.008024
2 atLeastBasicSanitation 0.007434

Cross validating XGBoost model with hyperparameters¶

In [97]:
from sklearn.model_selection import cross_val_score, KFold

kfold = KFold(n_splits=5, shuffle=True, random_state=42)

cv_scores = cross_val_score(best_model, X, y, cv=kfold, scoring='neg_mean_squared_error')

# Convert the negative MSE scores to positive and calculate RMSE
cv_rmse_scores = np.sqrt(-cv_scores)

# Calculate the mean and standard deviation of RMSE scores
mean_rmse = np.mean(cv_rmse_scores)
std_rmse = np.std(cv_rmse_scores)

print("Cross-Validation RMSE Scores:", cv_rmse_scores)
print("Mean RMSE:", mean_rmse)
print("Standard Deviation of RMSE:", std_rmse)
Cross-Validation RMSE Scores: [1.96039727 2.3057763  2.3338948  2.07144672 1.88021005]
Mean RMSE: 2.1103450261584067
Standard Deviation of RMSE: 0.1817291399043206